import cvxpy as cp  # New
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from .base import Trainer

# drops:
#     is_upsampling: False
#     gamma: 0.5              # gamma-parameter for focal/ldam(C)
#     beta: 0.9999            # beta-parameter for cb
#     s: 1                    # s-parameter for scaling logits of ldam
# Ignoring this, update once per epoch instead


class DropsTrainer(Trainer):
    """Trainer that adds a DROPS class-weight update and DRO evaluation.

    After every training epoch, ``drops_updateg`` runs on ``self.val_loader``. It
    updates the per-class weights ``g_y``, the Lagrange multiplier ``lambd`` and
    the logit multipliers ``alpha_y`` stored in ``self.drops_info``.

    Evaluation edits the logits with ``modify_logits``. It then reports the
    worst-case class-weighted accuracy inside divergence balls of several radii
    around the uniform class distribution (see ``solve_dro_reweight``).
    """

    def __init__(
        self,
        config,
        model,
        logger,
        train_set,
        test_set,
        criterion,
        optimizer,
        scheduler=None,
        val_set=None,
    ):
        super().__init__(
            config,
            model,
            logger,
            train_set,
            test_set,
            criterion,
            optimizer,
            scheduler,
            val_set,
        )

        ########## Part 1: dataset info ##########
        # TODO: move this to the common (base) parameters.
        self.num_classes, self.samples_per_cls = self.find_class_info(self.train_loader)

        ########## Part 2: DRO eval info ##########
        # TODO: move this to the common (base) parameters.
        # TODO: How do we handle logits_editing with other methods?
        self.dro_eval_parameters = self.dro_eval_init(self.num_classes)
        # One of [None, "drops", "posthoc", "posthoc_ce"]; see modify_logits.
        self.logits_editing = "drops"

        ########## Part 3: DRO info ##########
        # TODO: DROPS requires a validation set. Without one, the DROPS update
        # falls back to the test loader, i.e. the weights are fitted on test data.
        if self.val_set is None:
            self.val_loader = self.test_loader
        self.drops_info = self.init_drops()

        # TODO: How do we save and load self.drops_info?

    def run(self):
        """Train for ``self.epoch`` epochs with a DROPS update after each epoch.

        Each epoch does the following, in order:

        - one optimisation pass over ``self.train_loader``;
        - a call to ``drops_updateg``;
        - evaluation on the validation set (if one was given) and on the test set;
        - best, last and periodic checkpointing;
        - a scheduler step.
        """
        print("==> Start training..")
        best_acc = 0.0
        log_frequency = self.config["general"]["logger"]["frequency"]
        for cur_epoch in range(self.epoch):
            self.model.train()
            # epoch_loss is a Python float holding the sum of per-sample losses.
            # Accumulating the loss tensor itself would keep every batch's
            # autograd graph alive until the end of the epoch.
            epoch_loss, epoch_correct, total_num = 0.0, 0, 0
            with tqdm(self.train_loader, unit="batch") as tepoch:
                tepoch.set_description(f"Epoch {cur_epoch}")
                for data in tepoch:
                    inputs, labels, _, _ = self.prepare_data(data)
                    self.optimizer.zero_grad()
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, labels)
                    loss.backward()
                    self.optimizer.step()

                    batch_size = inputs.size(0)
                    batch_loss = loss.item()
                    correct = (outputs.argmax(1) == labels).sum().item()
                    batch_acc = 100.0 * correct / batch_size
                    tepoch.set_postfix(
                        loss=batch_loss,
                        accuracy=batch_acc,
                        lr=self.get_lr(),
                    )
                    # The criterion's value is treated as a per-batch mean, so
                    # scale by the batch size to get a per-sample epoch average.
                    epoch_loss += batch_loss * batch_size
                    epoch_correct += correct
                    total_num += batch_size
                    self.global_iter += 1
                    if self.global_iter % log_frequency == 0:
                        lr = self.get_lr()
                        self.logger.info(
                            f"[{cur_epoch}]/[{self.epoch}], Global Iter: {self.global_iter}, Loss: {batch_loss:.4f}, Acc: {batch_acc:.4f}, lr: {lr:.6f}",
                            {
                                "cur_epoch": cur_epoch,
                                "iter": self.global_iter,
                                "loss": batch_loss,
                                "Accuracy": batch_acc,
                                "lr": lr,
                            },
                        )

            #### Drops Update ####
            self.drops_updateg()

            denom = max(total_num, 1)  # guard against an empty train loader
            epoch_loss /= denom
            epoch_acc = epoch_correct / denom * 100.0
            if self.val_set is not None:
                self.evaluate(val=True)
            test_acc = self.evaluate(val=False)

            if test_acc > best_acc:
                best_acc = test_acc
                self.save_best_model()
            print(
                f"Epoch: {cur_epoch}, Loss: {epoch_loss:.6f}, Train Acc: {epoch_acc:.4f}, Test Acc: {test_acc:.4f}, Best Test Acc: {best_acc:.4f}"
            )
            self.logger.info(
                f"[{cur_epoch}]/[{self.epoch}], Loss: {epoch_loss:.6f}, Train Acc: {epoch_acc:.4f}, Test Acc: {test_acc:.4f}, Best Test Acc: {best_acc:.4f}",
                {
                    "test_epoch": cur_epoch,
                    "loss": epoch_loss,
                    "Train Acc": epoch_acc,
                    "Test Acc": test_acc,
                    "Best Test Acc": best_acc,
                },
            )

            if self.scheduler:
                self.scheduler.step()
            self.save_last_model()

            if cur_epoch % self.save_every_epoch == 0:
                self.save_model(f"{cur_epoch}")

    #############################
    # 👇 Drops algorithm 👇      #
    def init_drops(self):
        """Read the DROPS hyper-parameters and build the initial DROPS state.

        Side effect: sets ``self.drops_parameters`` from ``self.config["train"]``.

        Returns:
            A dict with the following entries:

            - ``g_y``: float32 CPU tensor, class weights normalised to sum to 1.
            - ``alpha_y``: float32 CPU tensor equal to ``N / n_y``.
            - ``r_list``: a copy of the initial ``g_y``. It is the reference
              ``u`` in the constraint ``D(u, g) <= delta``.
            - ``lambd``: the Lagrange multiplier, initialised to 1.0.

        Raises:
            ValueError: for an unsupported ``metric_base`` / ``re_weight_type``
                combination.
        """
        train_cfg = self.config["train"]
        self.drops_parameters = {
            key: train_cfg[key]
            for key in (
                "metric_base",
                "eta_g",
                "eta_lambda",
                "weight_type",
                "g_type",
                "re_weight_type",
            )
        }
        metric_base = self.drops_parameters["metric_base"]
        re_weight_type = self.drops_parameters["re_weight_type"]
        counts = np.asarray(self.samples_per_cls, dtype=np.float64)

        # Initial g_y: uniform, or derived from the training class counts.
        # NOTE(review): "prior" gives g_y ∝ 1/n_y and "recip_prior" gives g_y ∝ n_y,
        # which looks swapped relative to the names. Confirm against the DROPS
        # reference before changing it.
        if metric_base == "uniform":
            g_y = np.ones(self.num_classes)
        elif metric_base == "prior" and re_weight_type == "prior":
            g_y = 1.0 / counts
        elif metric_base == "recip_prior" and re_weight_type == "prior":
            g_y = counts
        else:
            raise ValueError(
                f"Unsupported DROPS init: metric_base={metric_base!r}, "
                f"re_weight_type={re_weight_type!r}"
            )
        g_y = g_y / np.sum(g_y)

        alpha_y = torch.tensor(1.0 / counts, dtype=torch.float32)
        alpha_y *= torch.tensor(counts, dtype=torch.float32).sum()

        g_y = torch.tensor(g_y, dtype=torch.float32)
        return {
            "g_y": g_y,
            "alpha_y": alpha_y,
            "lambd": 1.0,
            "r_list": g_y.clone(),
        }

    @torch.no_grad()
    def _drops_per_class_loss(self, alpha_y, weight_type):
        """Average each class's per-sample loss over ``self.val_loader``.

        Predictions are post-shifted as ``softmax(logits) * alpha_y``.

        - If ``weight_type == "ce"``, the loss is cross-entropy on the
          renormalised shifted probabilities.
        - Otherwise, it is the 0-1 loss of the shifted argmax.

        Returns:
            Float tensor of length ``self.num_classes`` on ``self.device``.
            Classes with no validation samples get a loss of 0.
        """
        device = self.device
        num_classes = self.num_classes
        loss_sum = torch.zeros(num_classes, dtype=torch.float32, device=device)
        count = torch.zeros(num_classes, dtype=torch.long, device=device)

        with tqdm(self.val_loader, unit="batch") as tepoch:
            for data in tepoch:
                images = data[0].to(device)
                y_true = data[1].to(device).long()
                shifted = F.softmax(self.model(images), dim=-1) * alpha_y
                if weight_type == "ce":
                    # NOTE(review): cross_entropy applies log_softmax to what are
                    # already normalised probabilities; kept as originally written.
                    shifted = shifted / torch.sum(shifted, dim=1, keepdim=True)
                    sample_loss = F.cross_entropy(shifted, y_true, reduction="none")
                else:  # 0-1 loss
                    sample_loss = (shifted.argmax(dim=1) != y_true).float()
                # One-hot matmul sums losses and counts per class in one op.
                one_hot = F.one_hot(y_true, num_classes).float()
                loss_sum += sample_loss @ one_hot
                count += one_hot.sum(dim=0).long()

        # clamp avoids 0/0 = NaN for classes absent from the validation set.
        return loss_sum / count.clamp(min=1)

    @staticmethod
    def _drops_divergence(r_list, g_y, dro_div):
        """Return the divergence between reference ``r_list`` and weights ``g_y``.

        Supported values of ``dro_div``:

        - ``"l2"``: squared L2 distance.
        - ``"l1"``: L1 distance.
        - ``"reverse-kl"``: ``sum r * log(r / g)``.
        - ``"kl"``: ``sum g * log(g / r)``.

        A 1e-12 floor keeps the logs finite.

        Raises:
            ValueError: for any other ``dro_div``.
        """
        if dro_div == "l2":
            return torch.sum(torch.square(r_list - g_y))
        if dro_div == "l1":
            return torch.sum(torch.abs(r_list - g_y))
        if dro_div == "reverse-kl":
            return torch.sum(r_list * torch.log((r_list + 1e-12) / (g_y + 1e-12)))
        if dro_div == "kl":
            return torch.sum(g_y * torch.log((g_y + 1e-12) / (r_list + 1e-12)))
        raise ValueError(f"Invalid dro_div: {dro_div}")

    @torch.no_grad()
    def drops_updateg(self):
        """Run one DROPS update of ``g_y``, ``lambd`` and ``alpha_y``.

        Reads the current state from ``self.drops_info``. It measures per-class
        validation losses under the current ``alpha_y`` shift and writes the
        updated ``g_y``, ``lambd`` and ``alpha_y`` back into ``self.drops_info``.
        The model is put in eval mode during the pass and returned to train mode
        afterwards, even if an error is raised.

        Raises:
            ValueError: for an unsupported divergence, or for ``g_type == "eg"``
                with a divergence other than ``"kl"`` / ``"reverse-kl"``.
        """
        device = self.device

        # Dynamic DROPS state.
        g_y = self.drops_info["g_y"].to(device)
        lambd = self.drops_info["lambd"]
        alpha_y = self.drops_info["alpha_y"].to(device)
        r_list = self.drops_info["r_list"].to(device)

        # Static DROPS hyper-parameters.
        eta_g = self.drops_parameters["eta_g"]
        eta_lambda = self.drops_parameters["eta_lambda"]
        weight_type = self.drops_parameters["weight_type"]
        g_type = self.drops_parameters["g_type"]

        # The DRO-eval radius doubles as the constraint level delta.
        delta = self.dro_eval_parameters["eps"]
        dro_div = self.dro_eval_parameters["dro_div"]

        self.model.eval()
        try:
            # Steps 1-3: per-class loss L_y of the alpha_y-shifted predictions.
            loss_y = self._drops_per_class_loss(alpha_y, weight_type)

            # Step 4: constraint term cons = lambda * (D(r, g) - delta).
            div = self._drops_divergence(r_list, g_y, dro_div).float()
            cons = lambd * (div - delta)

            # Step 6: update g_y, then renormalise onto the simplex.
            if g_type == "eg":
                if dro_div == "kl":
                    neu = (
                        eta_g * loss_y
                        + lambd * eta_g * torch.log(r_list + 1e-12)
                        + torch.log(g_y + 1e-12)
                    )
                    g_y_updated = torch.exp(neu / (lambd * eta_g + 1.0) - 1.0)
                elif dro_div == "reverse-kl":
                    # NOTE(review): divides by the per-class loss, so a class
                    # with zero loss yields inf; kept as originally written.
                    neu = g_y + eta_g * lambd * torch.log(r_list + 1e-12)
                    g_y_updated = neu / (eta_g * loss_y)
                else:
                    raise ValueError(
                        f"EG update is not implemented for dro_div: {dro_div}"
                    )
            else:
                g_y_updated = r_list * torch.exp(loss_y / lambd)
            g_y_updated = g_y_updated / torch.sum(g_y_updated)

            # Step 7: update the Lagrange multiplier.
            # NOTE(review): the step is eta_lambda * lambd * (D - delta), i.e. it
            # is scaled by lambd itself, and lambd is not projected onto >= 0.
            # Confirm against the DROPS reference before changing it.
            lambd_updated = lambd + eta_lambda * cons

            # Step 8: alpha_y = g_y / n_y, rescaled by the total sample count N.
            counts = torch.as_tensor(
                self.samples_per_cls, dtype=torch.float32, device=device
            )
            alpha_y_updated = g_y_updated / counts * counts.sum()
        finally:
            self.model.train()

        self.drops_info["g_y"] = g_y_updated
        self.drops_info["lambd"] = lambd_updated
        self.drops_info["alpha_y"] = alpha_y_updated

    # 👆 Drops algorithm 👆      #
    #############################

    #############################
    # 👇 DRO Evaluation👇        #
    def dro_eval_init(self, num_classes):
        """Return the DRO evaluation settings, including the radii to report.

        The returned dict contains:

        - ``tau``: scale for logit editing.
        - ``dro_div``: the divergence used.
        - ``eps``: the radius used as ``delta`` in the DROPS update.
        - ``eps_list``: 12 radii evaluated by ``evaluate``. The first is 0, which
          forces uniform class weights.

        Raises:
            ValueError: if ``dro_div`` has no ``eps_list`` recipe.
        """
        dro_eval_parameters = {
            "tau": 1.0,
            "dro_div": "kl",
            "eps": 0.9,
        }
        dro_div = dro_eval_parameters["dro_div"]
        if dro_div == "reverse-kl":
            # Radii 0, 0.5, ..., 4.5, then 10 and the training radius eps.
            eps_list = [i / 2 for i in range(21)]
            eps_list[10] = eps_list[-1]
            eps_list[11] = dro_eval_parameters["eps"]
        elif dro_div == "kl":
            # KL(v || uniform) is at most log(K); sample fractions of that bound.
            upp_v = np.log(num_classes)
            eps_list = [tmp_v * upp_v / 20 for tmp_v in range(21)]
            eps_list[10] = eps_list[-1]
            eps_list[11] = 1.0
        else:
            raise ValueError(f"No eps_list recipe for dro_div: {dro_div}")
        dro_eval_parameters["eps_list"] = eps_list[:12]

        return dro_eval_parameters

    def evaluate(self, val=True, second_model=False) -> float:
        """Evaluate with logit editing and print DRO accuracies for ``eps_list``.

        Args:
            val: evaluate on ``self.val_loader`` if True, else on
                ``self.test_loader``.
            second_model: evaluate ``self.model2`` instead of ``self.model``.
                Falls back to ``self.model`` if the attribute does not exist.

        Returns:
            The DRO accuracy (a fraction in [0, 1]) at the first radius of
            ``eps_list``. It is ``nan`` if the solver produced no weights.

        Raises:
            ValueError: if some class has no samples in the chosen loader, which
                leaves its accuracy undefined.
        """
        model_test = self.model
        if second_model:
            try:
                model_test = self.model2
            except AttributeError:
                print("There is no second model. Still testing the first model.")
        model_test.eval()
        loader = self.val_loader if val else self.test_loader
        evaluate_type = "Val" if val else "Test"

        ############################################
        # Step 1: Calculate the accuracy per class #
        ############################################
        num_samples_cls = torch.zeros(self.num_classes, dtype=torch.long)
        num_correct_cls = torch.zeros(self.num_classes, dtype=torch.long)
        with torch.no_grad():
            for data in loader:
                inputs, labels, _, _ = self.prepare_data(data)
                outputs = self.modify_logits(model_test(inputs), self.logits_editing)
                preds = outputs.argmax(1).cpu()
                labels = labels.cpu().long()
                num_samples_cls += torch.bincount(labels, minlength=self.num_classes)
                num_correct_cls += torch.bincount(
                    labels[preds == labels], minlength=self.num_classes
                )

        missing = torch.nonzero(num_samples_cls == 0).flatten().tolist()
        if missing:
            raise ValueError(
                f"{evaluate_type} set has no samples for classes {missing}; "
                "per-class accuracy is undefined."
            )
        acc_per_cls = (num_correct_cls.double() / num_samples_cls.double()).numpy()

        ############################################
        # Step 2: Solve reweighting per class      #
        # Step 3: Apply weight for dro acc         #
        ############################################
        eps_list = self.dro_eval_parameters["eps_list"]
        dro_div = self.dro_eval_parameters["dro_div"]
        dro_acc_list = []
        for eps_ in eps_list:
            reweight_v = self.solve_dro_reweight(acc_per_cls, dro_div, eps_)
            if reweight_v is None:  # solver returned no solution
                dro_acc_list.append(float("nan"))
            else:
                dro_acc_list.append(float(np.sum(reweight_v * acc_per_cls)))

        print(f"{evaluate_type} DRO EPS list: {[f'{eps_:.2f}' for eps_ in eps_list]}")
        print(
            f"{evaluate_type} DRO ACC list: {[f'{dro_acc:.2f}' for dro_acc in dro_acc_list]}"
        )

        return dro_acc_list[0]  # TODO dro_acc is a list instead of single value

    def modify_logits(self, logits, logits_editing=None):
        """Return ``logits`` shifted according to ``logits_editing``.

        Here ``n_y`` is ``self.samples_per_cls`` (training-set class counts) and
        ``tau`` comes from ``self.dro_eval_parameters``. Supported modes:

        - ``"drops"``: returns ``logits + tau * log(alpha_y)``, where
          ``alpha_y = g_y / n_y * sum(n)`` and ``g_y`` comes from
          ``self.drops_info``.
        - ``"posthoc"`` / ``"posthoc_ce"``: returns
          ``logits - tau * log(n_y / sum(n))``.
        - ``None``: returns ``logits`` unchanged.

        A 1e-12 floor keeps the logs finite.

        Raises:
            ValueError: for any other ``logits_editing`` value.
        """
        if logits_editing is None:
            return logits
        if logits_editing not in ("drops", "posthoc", "posthoc_ce"):
            raise ValueError(f"Invalid logits_editing: {logits_editing}")

        tau = self.dro_eval_parameters["tau"]
        counts = torch.as_tensor(
            self.samples_per_cls, dtype=torch.float32, device=logits.device
        )
        if logits_editing == "drops":
            g_y = self.drops_info["g_y"].to(device=logits.device, dtype=torch.float32)
            final_alpha_y = g_y / counts * counts.sum()
            return logits + tau * torch.log(final_alpha_y + 1e-12)

        # "posthoc" / "posthoc_ce": subtract the log of the training prior.
        prior = counts / counts.sum()
        return logits - tau * torch.log(prior + 1e-12)

    def solve_dro_reweight(self, acc_per_cls, dro_div, eps):
        """Return worst-case class weights inside a divergence ball around uniform.

        Solves ``min_v sum(v * acc_per_cls)`` over the probability simplex,
        subject to ``D(v, u) <= eps`` with ``u`` uniform, using cvxpy. If the
        default solver raises ``SolverError``, it retries with SCS.

        Returns:
            ``v.value``: a numpy array of length ``self.num_classes``, or
            ``None`` if the solver produced no value.

        Raises:
            ValueError: for an unsupported ``dro_div``.
        """
        base_weight_norm = np.ones(self.num_classes) / self.num_classes
        v = cp.Variable(self.num_classes)
        constraints = [v >= 0, cp.sum(v) == 1]
        if dro_div == "l2":
            constraints.append(cp.sum(cp.square(v - base_weight_norm)) <= eps)
        elif dro_div == "l1":
            constraints.append(cp.sum(cp.abs(v - base_weight_norm)) <= eps)
        elif dro_div == "reverse-kl":
            # D(g, u) = sum_i u_i * [log(u_i) - log(g_i)],
            # g is the variable v being solved for, u is base_weight_norm.
            constraints.append(cp.sum(cp.kl_div(base_weight_norm, v)) <= eps)
        elif dro_div == "kl":
            # D(g, u) = sum_i g_i * [log(g_i) - log(u_i)],
            # g is the variable v being solved for, u is base_weight_norm.
            constraints.append(cp.sum(cp.kl_div(v, base_weight_norm)) <= eps)
        else:
            raise ValueError(f"Invalid dro_div: {dro_div}")

        obj = cp.Minimize(cp.sum(cp.multiply(v, acc_per_cls)))
        prob = cp.Problem(obj, constraints)
        # Warm-start from the uniform weights.
        v.value = v.project(base_weight_norm)
        try:
            prob.solve(warm_start=True)
        except cp.error.SolverError:
            prob.solve(solver="SCS", warm_start=True)

        return v.value

    # 👆 DRO Evaluation 👆       #
    #############################

    #############################
    # 👇 Add this to base 👇     #
    def find_class_info(self, loader):
        """Return ``(num_classes, samples_per_cls)`` for the targets of ``loader``.

        ``samples_per_cls`` is ``np.bincount`` of the targets, so labels must be
        non-negative integers. ``num_classes`` is ``len(samples_per_cls)``. This
        keeps every per-class array in this trainer the same length.

        ``torch.utils.data.Subset`` datasets are resolved through their
        ``indices``. ``targets`` may be a list, a numpy array or a CPU tensor.
        """
        dataset = loader.dataset
        if isinstance(dataset, torch.utils.data.Subset):
            targets = np.asarray(dataset.dataset.targets)
            selected_targets = targets[np.asarray(dataset.indices, dtype=np.int64)]
        else:
            selected_targets = np.asarray(dataset.targets)
        samples_per_cls = np.bincount(selected_targets)

        return len(samples_per_cls), samples_per_cls

    # 👆 Add this to base 👆     #
    #############################
